-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 #125685
Conversation
Adds AVX512 bf16 conversion from packed f32 to bf16 elements. Tests are slightly refactored to better follow the file's conventions.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesAdds AVX512 bf16 conversion from packed f32 to bf16 elements. Tests are slightly refactored to better follow file's convention. Full diff: https://github.com/llvm/llvm-project/pull/125685.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 16181d7e760db5..566013e73f4b89 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
}
+//----------------------------------------------------------------------------//
+// Convert packed F32 to packed BF16
+//----------------------------------------------------------------------------//
+
+def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
+ AllElementCountsMatch<["a", "dst"]>]> {
+ let summary = "Convert packed F32 to packed BF16 Data.";
+ let description = [{
+ The `convert_f32_to_bf16` op is an AVX512-BF16 specific op that can lower
+ to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending on
+ the width of MLIR vectors it is applied to.
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed single-precision (32-bit) floating-point elements in `a` to
+ packed BF16 (16-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ ```
+ }];
+ let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict `:` type($a) `->` type($dst)";
+}
+
+def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
+ /*extension=*/"bf16"> {
+ let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
+}
+
+def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
+ /*extension=*/"bf16"> {
+ let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
+}
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 260ac9ce589a38..f1fbb39b97fc49 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
}
};
+struct CvtPackedF32ToBF16Conversion
+ : public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
+ using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto typeA = dyn_cast<VectorType>(op.getA().getType());
+ unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+ unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+ auto opType = op.getDst().getType();
+ auto opA = op.getA();
+
+ switch (opBitWidth) {
+ case 256: {
+ rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
+ break;
+ }
+ case 512: {
+ rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
+ break;
+ }
+ default: {
+ return rewriter.notifyMatchFailure(
+ op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
+ }
+ }
+
+ return success();
+ }
+};
+
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
- patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
- DotOpConversion>(converter);
+ patterns
+ .add<MaskCompressOpConversion, DotBF16OpConversion,
+ CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
+ converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
target.addLegalOp<DotBF16Ps256IntrOp>();
target.addLegalOp<DotBF16Ps512IntrOp>();
target.addIllegalOp<DotBF16Op>();
+ target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
+ target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
+ target.addIllegalOp<CvtPackedF32ToBF16Op>();
target.addLegalOp<RsqrtIntrOp>();
target.addIllegalOp<RsqrtOp>();
target.addLegalOp<DotIntrOp>();
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
new file mode 100644
index 00000000000000..c97c52f01c3b03
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
@@ -0,0 +1,24 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> vector<8xbf16> {
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
+// CHECK: vcvtneps2bf16{{.*}}%xmm
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> vector<16xbf16> {
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
+// CHECK: vcvtneps2bf16{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index ed9177eaec9ce4..59be7dd75b3b0b 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+ // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+ // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index cf74a7ee602558..0d00448c63da88 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+ // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+ // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+ // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+ // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 1df03f10c93214..db1c10cd5cd37a 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
- %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+ %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
- %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
: (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
llvm.return %0 : vector<4xf32>
}
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
- %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+ %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
- %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
: (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
llvm.return %0 : vector<8xf32>
}
// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
- %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+ %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
) -> vector<16xf32>
{
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
- %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
: (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
llvm.return %0 : vector<16xf32>
}
+// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
+ %a: vector<8xf32>) -> vector<8xbf16>
+{
+ // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
+ %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
+ : (vector<8xf32>) -> vector<8xbf16>
+ llvm.return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
+ %a: vector<16xf32>) -> vector<16xbf16>
+{
+ // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
+ %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
+ : (vector<16xf32>) -> vector<16xbf16>
+ llvm.return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
llvm.func @LLVM_x86_avx_dp_ps_256(
- %arg0: vector<8xf32>, %arg1: vector<8xf32>
+ %a: vector<8xf32>, %b: vector<8xf32>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
- %0 = llvm.mlir.constant(-1 : i8) : i8
- %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+ %c = llvm.mlir.constant(-1 : i8) : i8
+ %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
llvm.return %1 : vector<8xf32>
}
|
@llvm/pr-subscribers-mlir-llvm Author: Adam Siemieniuk (adam-smnk) ChangesAdds AVX512 bf16 conversion from packed f32 to bf16 elements. Tests are slightly refactored to better follow file's convention. Full diff: https://github.com/llvm/llvm-project/pull/125685.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 16181d7e760db5f..566013e73f4b890 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
}
+//----------------------------------------------------------------------------//
+// Convert packed F32 to packed BF16
+//----------------------------------------------------------------------------//
+
+def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
+ AllElementCountsMatch<["a", "dst"]>]> {
+ let summary = "Convert packed F32 to packed BF16 Data.";
+ let description = [{
+ The `convert_f32_to_bf16` op is an AVX512-BF16 specific op that can lower
+ to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending on
+ the width of MLIR vectors it is applied to.
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed single-precision (32-bit) floating-point elements in `a` to
+ packed BF16 (16-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ ```
+ }];
+ let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict `:` type($a) `->` type($dst)";
+}
+
+def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
+ /*extension=*/"bf16"> {
+ let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
+}
+
+def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
+ /*extension=*/"bf16"> {
+ let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
+}
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 260ac9ce589a38f..f1fbb39b97fc498 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
}
};
+struct CvtPackedF32ToBF16Conversion
+ : public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
+ using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto typeA = dyn_cast<VectorType>(op.getA().getType());
+ unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+ unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+ auto opType = op.getDst().getType();
+ auto opA = op.getA();
+
+ switch (opBitWidth) {
+ case 256: {
+ rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
+ break;
+ }
+ case 512: {
+ rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
+ break;
+ }
+ default: {
+ return rewriter.notifyMatchFailure(
+ op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
+ }
+ }
+
+ return success();
+ }
+};
+
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
- patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
- DotOpConversion>(converter);
+ patterns
+ .add<MaskCompressOpConversion, DotBF16OpConversion,
+ CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
+ converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
target.addLegalOp<DotBF16Ps256IntrOp>();
target.addLegalOp<DotBF16Ps512IntrOp>();
target.addIllegalOp<DotBF16Op>();
+ target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
+ target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
+ target.addIllegalOp<CvtPackedF32ToBF16Op>();
target.addLegalOp<RsqrtIntrOp>();
target.addIllegalOp<RsqrtOp>();
target.addLegalOp<DotIntrOp>();
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
new file mode 100644
index 000000000000000..c97c52f01c3b033
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
@@ -0,0 +1,24 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> vector<8xbf16> {
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
+// CHECK: vcvtneps2bf16{{.*}}%xmm
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> vector<16xbf16> {
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
+// CHECK: vcvtneps2bf16{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index ed9177eaec9ce4a..59be7dd75b3b0b8 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+ // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+ // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index cf74a7ee602558f..0d00448c63da889 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+ // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+ // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+ // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+ // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 1df03f10c93214a..db1c10cd5cd37a2 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
- %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+ %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
- %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
: (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
llvm.return %0 : vector<4xf32>
}
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
- %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+ %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
- %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
: (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
llvm.return %0 : vector<8xf32>
}
// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
- %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+ %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
) -> vector<16xf32>
{
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
- %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
: (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
llvm.return %0 : vector<16xf32>
}
+// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
+ %a: vector<8xf32>) -> vector<8xbf16>
+{
+ // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
+ %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
+ : (vector<8xf32>) -> vector<8xbf16>
+ llvm.return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
+ %a: vector<16xf32>) -> vector<16xbf16>
+{
+ // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
+ %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
+ : (vector<16xf32>) -> vector<16xbf16>
+ llvm.return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
llvm.func @LLVM_x86_avx_dp_ps_256(
- %arg0: vector<8xf32>, %arg1: vector<8xf32>
+ %a: vector<8xf32>, %b: vector<8xf32>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
- %0 = llvm.mlir.constant(-1 : i8) : i8
- %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+ %c = llvm.mlir.constant(-1 : i8) : i8
+ %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
llvm.return %1 : vector<8xf32>
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as before. Not the best design, but following the existing pattern. This is helping us understand the design and will be the target of an RFC soon on how to simplify CPU dialect lowering without needing a full blown VM dialect.
Adds AVX512 bf16 conversion from packed f32 to bf16 elements. Tests are slightly refactored to better follow file's convention.
Adds AVX512 bf16 conversion from packed f32 to bf16 elements.
Tests are slightly refactored to better follow the file's conventions.